{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a0d4bcd9-5886-4293-bc25-a5373efacc33",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import spacy\n",
    "import time\n",
    "import torch\n",
    "\n",
    "from anchor import anchor_text\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8192f85d-af53-4a73-a5de-282d0557be9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# spaCy pipeline 'en_core_web_sm'; this object is handed to AnchorText when the explainers are built.\n",
    "nlp = spacy.load(\"en_core_web_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "37728754-72e6-450f-867f-2de9a9a61525",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "\n",
    "# Run on the GPU when torch reports one is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Sentiment pipeline; model and tokenizer are loaded from the same checkpoint id.\n",
    "# predict_prob reads entry [0] of every prediction this pipeline returns.\n",
    "classifier = pipeline(\n",
    "    task=\"sentiment-analysis\",\n",
    "    model=\"siebert/sentiment-roberta-large-english\",\n",
    "    tokenizer=\"siebert/sentiment-roberta-large-english\",\n",
    "    top_k=1,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c2a3f202-d836-4afd-9e56-3a10e40bd834",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_prob(texts):\n",
    "    \"\"\"Map each text to a hard class index using the sentiment pipeline.\n",
    "\n",
    "    Despite the name, no probabilities are returned: the result holds 0 where\n",
    "    the first label reported for a text is 'NEGATIVE' and 1 for any other label.\n",
    "\n",
    "    Args:\n",
    "        texts: input forwarded unchanged to ``classifier``.\n",
    "\n",
    "    Returns:\n",
    "        A numpy array with one 0/1 entry per prediction returned by ``classifier``.\n",
    "    \"\"\"\n",
    "    class_ids = []\n",
    "    for ranked in classifier(texts):\n",
    "        # Only the first entry of each prediction is inspected.\n",
    "        top_label = ranked[0]['label']\n",
    "        class_ids.append(0 if top_label == 'NEGATIVE' else 1)\n",
    "    return np.array(class_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d5c936b-52b6-4ea1-a033-07a16e6252a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names are ordered so that positions 0 and 1 match the indices produced by predict_prob.\n",
    "explainer = anchor_text.AnchorText(\n",
    "    nlp,\n",
    "    ['NEGATIVE', 'POSITIVE'],\n",
    "    use_unk_distribution=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09715568-21f0-444b-9258-943b5b696f98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction: POSITIVE\n"
     ]
    }
   ],
   "source": [
    "# Classify one sentence and translate the returned class index into its class name.\n",
    "text = 'The little mermaid is a good story.'\n",
    "pred = explainer.class_names[predict_prob([text])[0]]\n",
    "print(f'Prediction: {pred}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f5b1348-39a7-413f-b53e-4a361dc32dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the explanation for `text`, using predict_prob as the classifier function.\n",
    "exp = explainer.explain_instance(\n",
    "    text,\n",
    "    predict_prob,\n",
    "    threshold=0.95,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7ab98e06-1bcb-4f87-b889-8fd03aedff94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Anchor: good AND a AND is\n",
      "Precision: 1.00\n"
     ]
    }
   ],
   "source": [
    "# Show the anchor's names joined by AND, then its precision to two decimals.\n",
    "print('Anchor: {}'.format(' AND '.join(exp.names())))\n",
    "print('Precision: {:.2f}'.format(exp.precision()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "482c8e8b-f15e-4ef6-b732-45ca387b0212",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UNK UNK mermaid is a good story UNK\n",
      "The UNK UNK is a good story UNK\n",
      "UNK UNK UNK is a good UNK UNK\n",
      "The little mermaid is a good story UNK\n",
      "The UNK mermaid is a good UNK UNK\n",
      "The UNK UNK is a good story UNK\n",
      "The little mermaid is a good UNK UNK\n",
      "The little mermaid is a good story .\n",
      "The little UNK is a good story UNK\n",
      "The little mermaid is a good story .\n"
     ]
    }
   ],
   "source": [
    "# One line per example returned with only_same_prediction=True; only field [0] of each is shown.\n",
    "print(*(sample[0] for sample in exp.examples(only_same_prediction=True)), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "50b7d8fc-cd5d-4bea-a894-5bdd76464bb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# One line per example returned with only_different_prediction=True; only field [0] of each is shown.\n",
    "# With no such examples this still emits a single empty line.\n",
    "print(*(sample[0] for sample in exp.examples(only_different_prediction=True)), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8fcc23c5-32e4-4fd8-8fb2-3284f3f04f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second explainer, this time with use_unk_distribution=False. The class names use the\n",
    "# same 'NEGATIVE'/'POSITIVE' spelling as the label checked in predict_prob, so\n",
    "# explainer.class_names reads the same no matter which explainer cell ran last.\n",
    "# NOTE: this rebinds both `explainer` and `exp`; the cells below report on this explanation.\n",
    "explainer = anchor_text.AnchorText(\n",
    "    nlp,\n",
    "    ['NEGATIVE', 'POSITIVE'],\n",
    "    use_unk_distribution=False,\n",
    ")\n",
    "exp = explainer.explain_instance(text, predict_prob, threshold=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8cce3381-0c4b-4fd0-95bd-683eae3ce5d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The glad ##stone is a good story .\n",
      "The dough ##nut is a good bread !\n",
      "The rock track is a good selection .\n",
      "The copper ##head is a good method .\n",
      "The black knight is a good party !\n",
      "The open batsman is a good friend .\n",
      "The color matching is typically good visibility .\n",
      "The mineral ##ization is a good mineral .\n",
      "The The process is reasonably good distribution .\n",
      "The greatest strength is always good value :\n"
     ]
    }
   ],
   "source": [
    "# One line per example returned with only_same_prediction=True; only field [0] of each is shown.\n",
    "print(*(sample[0] for sample in exp.examples(only_same_prediction=True)), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6dd48d38-02f0-4e38-ab18-78c369956a3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The universal variety is slightly good reads :\n",
      "The little tower is a good ruin .\n",
      "The third reality is too good luck .\n",
      "The death threat is a good warning .\n",
      "The great divide is not good here !\n",
      "The population background is below good estimates .\n",
      "The little friend is a good liar ;\n",
      "The primary obstacle is finding good data ;\n"
     ]
    }
   ],
   "source": [
    "# One line per example returned with only_different_prediction=True; only field [0] of each is shown.\n",
    "print(*(sample[0] for sample in exp.examples(only_different_prediction=True)), sep='\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chapter09",
   "language": "python",
   "name": "chapter09"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
